// © 2021 Qualcomm Innovation Center, Inc. All rights reserved.
//
// SPDX-License-Identifier: BSD-3-Clause

#include <assert.h>
#include <hyptypes.h>

#include <compiler.h>
#include <idle.h>
#include <object.h>
#include <panic.h>
#include <partition.h>
#include <preempt.h>
#include <scheduler.h>
#include <thread.h>
#include <trace.h>

#include <events/thread.h>

#include "event_handlers.h"
#include "thread_arch.h"

// Alignment used for thread stacks, in bytes.
const size_t thread_stack_align = (size_t)16U;
// Default size of a thread stack, in bytes (4 KiB, i.e. 1 << 12).
const size_t thread_stack_size_default = (size_t)4096U;

// Compute the thread-pointer value for @thread, i.e. the value that
// thread_arch_switch_thread() writes to TPIDR_EL2 before running @thread.
//
// The asm materialises the link-time TP-relative offset of the TLS variable
// current_thread using the :tprel_hi12: / :tprel_lo12_nc: relocation
// operators. The result is @thread's address minus that offset.
//
// NOTE(review): this only makes sense if current_thread's TLS storage is the
// thread_t object itself (thread == TP + tprel(current_thread)). Confirm
// against the definition of current_thread / thread_get_self().
static uintptr_t
thread_get_tls_base(thread_t *thread)
{
	// Starts at 0; the two adds accumulate the high and low 12-bit parts of
	// current_thread's TP-relative offset.
	size_t offset = 0;
	__asm__("add     %0, %0, :tprel_hi12:current_thread	;"
		"add     %0, %0, :tprel_lo12_nc:current_thread	;"
		: "+r"(offset));
	return (uintptr_t)thread - offset;
}

// First code executed by a newly started thread.
//
// thread_arch_init_context() sets a new thread's saved context.pc to this
// function. The first switch into the thread therefore branches here with the
// previously running thread in x0, which arrives as @prev.
//
// Sequence:
//  1. Trigger the thread_start event.
//  2. Finish the switch away from @prev (post-switch event), then pass it to
//     object_put_thread().
//  3. Look up the entry function for this thread's kind.
//  4. Trigger thread_load_state(true).
//  5. If an entry function exists, enable preemption and call it with the
//     thread's params.
//
// The thread then exits; this function never returns.
static noreturn void
thread_arch_main(thread_t *prev)
{
	thread_t *const self = thread_get_self();

	trigger_thread_start_event();

	// Complete the switch on behalf of the outgoing thread, then release it.
	// NOTE(review): presumably this drops a reference held across the
	// switch — confirm.
	trigger_thread_context_switch_post_event(prev);
	object_put_thread(prev);

	const thread_func_t entry_fn =
		trigger_thread_get_entry_fn_event(self->kind);
	trigger_thread_load_state_event(true);

	if (entry_fn != NULL) {
		// Preemption is enabled only when there is thread code to run.
		preempt_enable();
		entry_fn(self->params);
	}

	thread_exit();
}

// Switch execution from the calling thread to @next_thread.
//
// The asm sequence:
//  - records the resume point (.Lthread_continue), the current sp and x29 into
//    the calling thread's context (pc, sp, fp, stored as one stp plus one
//    str at offset 16);
//  - loads @next_thread's saved sp and x29;
//  - sets TPIDR_EL2 to @next_thread's TLS base;
//  - branches to @next_thread's saved pc.
//
// The call "returns" only when some later switch branches back to this
// thread's .Lthread_continue label. At that point x0 holds the `old` value of
// the invocation that switched back, i.e. the thread we were switched from.
// That thread is returned.
//
// If @next_thread has never run, its saved pc is thread_arch_main(). The same
// x0 convention delivers the outgoing thread to it as its `prev` argument.
thread_t *
thread_arch_switch_thread(thread_t *next_thread)
{
	// Note: the old thread must be in X0 so that a switch to a new thread
	// with its PC set to thread_arch_main() will get the old thread as
	// its argument.
	register thread_t *old __asm__("x0") = thread_get_self();

	// The remaining hard-coded registers here are only needed to ensure a
	// correct clobber list below. The union of the clobber list, hard-coded
	// registers and explicitly saved registers (x29, sp and pc) must be the
	// entire integer register state.
	register register_t old_pc __asm__("x1");
	register register_t old_sp __asm__("x2");
	register register_t old_fp __asm__("x3");
	// Base address of the saved pc/sp/fp triple in the outgoing thread.
	register uintptr_t  old_context __asm__("x4") =
		(uintptr_t)&old->context.pc;
	// The asm saves pc and sp with a single stp and fp at offset 16, so the
	// three fields must be laid out contiguously in that order.
	static_assert(offsetof(thread_t, context.sp) ==
			      offsetof(thread_t, context.pc) +
				      sizeof(next_thread->context.pc),
		      "PC and SP must be adjacent in context");
	static_assert(offsetof(thread_t, context.fp) ==
			      offsetof(thread_t, context.sp) +
				      sizeof(next_thread->context.sp),
		      "SP and FP must be adjacent in context");
	register register_t new_pc __asm__("x5") = next_thread->context.pc;
	register register_t new_sp __asm__("x6") = next_thread->context.sp;
	register register_t new_fp __asm__("x7") = next_thread->context.fp;
	register uintptr_t  new_tls_base __asm__("x8") =
		thread_get_tls_base(next_thread);

	// Every operand is an output or in/out ("+r"), so the compiler treats all
	// of x0-x8 as modified. Presumably this is required because, on resume,
	// these registers hold whatever the switching-back invocation left in
	// them, not this invocation's values.
	//
	// The new sp/x29 are installed before the old context is stored. That
	// ordering is safe because old_context addresses the thread object, not
	// the stack.
	//
	// NOTE(review): FP/SIMD registers (v0-v31) are not in the clobber list.
	// This assumes compiled code keeps no live values in them across this
	// call (e.g. a general-regs-only build) — confirm.
	__asm__ volatile("adr	%[old_pc], .Lthread_continue.%=		;"
			 "mov	%[old_sp], sp				;"
			 "mov	%[old_fp], x29				;"
			 "mov   sp, %[new_sp]				;"
			 "mov   x29, %[new_fp]				;"
			 "msr	TPIDR_EL2, %[new_tls_base]		;"
			 "stp	%[old_pc], %[old_sp], [%[old_context]]	;"
			 "str	%[old_fp], [%[old_context], 16]		;"
			 "br	%[new_pc]				;"
			 ".Lthread_continue.%=:				;"
			 : [old] "+r"(old), [old_pc] "=&r"(old_pc),
			   [old_sp] "=&r"(old_sp), [old_fp] "=&r"(old_fp),
			   [old_context] "+r"(old_context),
			   [new_pc] "+r"(new_pc), [new_sp] "+r"(new_sp),
			   [new_fp] "+r"(new_fp),
			   [new_tls_base] "+r"(new_tls_base)
			 : /* This must not have any inputs */
			 : "x9", "x10", "x11", "x12", "x13", "x14", "x15",
			   "x16", "x17", "x18", "x19", "x20", "x21", "x22",
			   "x23", "x24", "x25", "x26", "x27", "x28", "x30",
			   "cc", "memory");

	// x0 now holds the thread that switched back to us.
	return old;
}

// Jump into @thread's saved context without saving the current one.
//
// Loads sp and x29 from @thread->context, puts @thread itself in x0 and
// branches to @thread->context.pc. If that pc is thread_arch_main() (as
// thread_standard_handle_boot_hypervisor_start() sets for the idle thread),
// thread_arch_main() runs with prev == @thread.
//
// TPIDR_EL2 is not written here. The caller's TLS must already belong to
// @thread, which the first assert checks via thread_get_self().
noreturn void
thread_arch_set_thread(thread_t *thread)
{
	// This should only be called on the idle thread during power-up, which
	// should already be the current thread for TLS. It discards the current
	// execution state.
	assert(thread == thread_get_self());
	assert(thread == idle_thread());

	// Note: the old thread must be in X0 so that a switch to a new thread
	// with its PC set to thread_arch_main() will get the old thread as
	// its argument.
	register thread_t *old __asm__("x0") = thread;

	register_t new_pc = thread->context.pc;
	register_t new_sp = thread->context.sp;
	register_t new_fp = thread->context.fp;

	// Control never comes back past the br, so the clobber list only needs
	// "memory" to order prior stores before the jump.
	__asm__ volatile("mov   sp, %[new_sp]			;"
			 "mov   x29, %[new_fp]			;"
			 "br	%[new_pc]			;"
			 : [old] "+r"(old)
			 : [new_pc] "r"(new_pc), [new_sp] "r"(new_sp),
			   [new_fp] "r"(new_fp)
			 : "memory");
	__builtin_unreachable();
}

// Save the calling thread's execution state, then call fn(param).
//
// The thread_save_state event is triggered first. The current resume point (a
// local label), sp and x29 are then stored into the thread's context, and fn
// is called with param in x0. There are two ways back:
//  - fn returns normally: its x0 result is returned.
//  - execution is later restored from the saved context (a path that loads
//    context.sp/fp and branches to context.pc, as thread_arch_set_thread()
//    does): thread_load_state(false) is triggered and resumed_result is
//    returned.
//
// Returns fn's result, or resumed_result when resumed from the saved context.
register_t
thread_freeze(register_t (*fn)(register_t), register_t param,
	      register_t resumed_result)
{
	TRACE(DEBUG, INFO, "thread_freeze start fn: {:#x} param: {:#x}",
	      (uintptr_t)fn, (uintptr_t)param);

	trigger_thread_save_state_event();

	thread_t *thread = thread_get_self();
	assert(thread != NULL);

	// fn takes its argument, and returns its result, in x0.
	register register_t x0 __asm__("x0") = param;

	// The remaining hard-coded registers here are only needed to
	// ensure a correct clobber list below. The union of the clobber
	// list, fixed output registers and explicitly saved registers
	// (x29, sp and pc) must be the entire integer register state.
	register register_t saved_pc __asm__("x1");
	register register_t saved_sp __asm__("x2");
	register uintptr_t  context __asm__("x3") =
		(uintptr_t)&thread->context.pc;
	register register_t (*fn_reg)(register_t) __asm__("x4") = fn;
	register bool is_resuming __asm__("x5");

	// pc and sp are stored with one stp and fp at offset 16, so the three
	// context fields must be contiguous in that order.
	static_assert(offsetof(thread_t, context.sp) ==
			      offsetof(thread_t, context.pc) +
				      sizeof(thread->context.pc),
		      "PC and SP must be adjacent in context");
	static_assert(offsetof(thread_t, context.fp) ==
			      offsetof(thread_t, context.sp) +
				      sizeof(thread->context.sp),
		      "SP and FP must be adjacent in context");

	// is_resuming is a plain output ("=r"). The commutative modifier '%' it
	// previously carried is only valid on read-only operands. On this output
	// it would allow the compiler to interchange it with saved_pc, so the
	// adr result could end up in is_resuming.
	__asm__ volatile("adr	%[saved_pc], .Lthread_freeze.resumed.%=	;"
			 "mov	%[saved_sp], sp				;"
			 "stp	%[saved_pc], %[saved_sp], [%[context]]	;"
			 "str	x29, [%[context], 16]			;"
			 "blr	%[fn_reg]				;"
			 "mov	%[is_resuming], 0			;"
			 "b	.Lthread_freeze.done.%=			;"
			 ".Lthread_freeze.resumed.%=:			;"
			 "mov	%[is_resuming], 1			;"
			 ".Lthread_freeze.done.%=:			;"
			 : [is_resuming] "=r"(is_resuming),
			   [saved_pc] "=&r"(saved_pc),
			   [saved_sp] "=&r"(saved_sp), [context] "+r"(context),
			   [fn_reg] "+r"(fn_reg), "+r"(x0)
			 : /* This must not have any inputs */
			 : "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13",
			   "x14", "x15", "x16", "x17", "x18", "x19", "x20",
			   "x21", "x22", "x23", "x24", "x25", "x26", "x27",
			   "x28", "x30", "cc", "memory");

	if (is_resuming) {
		x0 = resumed_result;
		trigger_thread_load_state_event(false);

		TRACE(DEBUG, INFO, "thread_freeze resumed: {:#x}", x0);
	} else {
		TRACE(DEBUG, INFO, "thread_freeze returned: {:#x}", x0);
	}

	return x0;
}

// Discard the current thread's call stack and run fn(param) from its top.
//
// sp is reset to stack_base + stack_size of the current thread. x29 is zeroed
// so the new frame chain has no parent frame record. fn is then called with
// param in x0. fn must not return; if it does, this panics.
//
// NOTE(review): the asm declares only a "memory" clobber (no registers). If fn
// did return, the compiler's assumptions about x30 and caller-saved registers
// would be stale. The only code reached afterwards is the panic() call.
noreturn void
thread_reset_stack(void (*fn)(register_t), register_t param)
{
	thread_t *	    thread	     = thread_get_self();
	// Pin param to x0 so it is fn's first argument at the blr.
	register register_t x0 __asm__("x0") = param;
	// Same initial sp value that thread_arch_init_context() uses.
	uintptr_t new_sp = (uintptr_t)thread->stack_base + thread->stack_size;

	__asm__ volatile("mov	sp, %[new_sp]	;"
			 "mov	x29, 0		;"
			 "blr	%[new_pc]	;"
			 :
			 : [new_pc] "r"(fn), [new_sp] "r"(new_sp), "r"(x0)
			 : "memory");
	panic("returned to thread_reset_stack()");
}

// Prepare @thread's saved context so that the first switch to it enters
// thread_arch_main().
//
// The thread starts on the top of its stack (stack_base + stack_size) with a
// zero frame pointer, so there is no caller frame to unwind into.
void
thread_arch_init_context(thread_t *thread)
{
	assert(thread != NULL);

	uintptr_t stack_top = (uintptr_t)thread->stack_base + thread->stack_size;

	thread->context.fp = (uintptr_t)0;
	thread->context.sp = stack_top;
	thread->context.pc = (uintptr_t)thread_arch_main;
}

// Boot-time hook: point the idle thread's saved pc at thread_arch_main().
//
// Only pc is updated; the idle thread's saved sp and fp are left untouched.
void
thread_standard_handle_boot_hypervisor_start(void)
{
	thread_t *idle = idle_thread();

	assert(idle != NULL);
	idle->context.pc = (uintptr_t)thread_arch_main;
}
